# https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion
from collections import defaultdict
from random import shuffle
from typing import NamedTuple
import torch
from scipy.optimize import linear_sum_assignment


# Stable Diffusion weight keys singled out for special treatment by the merge
# logic. NOTE(review): the consumers of this list are outside this file —
# presumably these norm/out layers are merged without permutation; confirm
# against the caller.
SPECIAL_KEYS = [
	"first_stage_model.decoder.norm_out.weight",
	"first_stage_model.decoder.norm_out.bias",
	"first_stage_model.encoder.norm_out.weight",
	"first_stage_model.encoder.norm_out.bias",
	"model.diffusion_model.out.0.weight",
	"model.diffusion_model.out.0.bias",
]


class PermutationSpec(NamedTuple):
	"""Two views of the same permutation structure over a model's weights."""
	# Maps a permutation name to the list of (weight_key, axis) pairs it acts on.
	perm_to_axes: dict
	# Maps a weight key to its per-axis permutation names (None = no permutation).
	axes_to_perm: dict


def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
	"""Build a PermutationSpec from a weight-key -> per-axis-permutation map.

	Inverts `axes_to_perm` so that each permutation name maps to the list of
	(weight_key, axis) pairs it acts on; axes marked None are skipped.
	"""
	perm_to_axes: dict = {}
	for weight_key, axis_perms in axes_to_perm.items():
		for axis_index, perm_name in enumerate(axis_perms):
			if perm_name is None:
				continue
			perm_to_axes.setdefault(perm_name, []).append((weight_key, axis_index))
	return PermutationSpec(perm_to_axes=perm_to_axes, axes_to_perm=axes_to_perm)


def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
	"""Get parameter `k` from `params`, with the permutations applied."""
	try:
		w = params[k]
	except KeyError:
		# If the key is not found in params, return None or handle it as needed
		return None

	try:
		axes_to_perm = ps.axes_to_perm[k]
	except KeyError:
		# If the key is not found in axes_to_perm, return the original parameter
		return w

	for axis, p in enumerate(axes_to_perm):
		# Skip the axis we're trying to permute.
		if axis == except_axis:
			continue

		# None indicates that there is no permutation relevant to that axis.
		if p:
			try:
				w = torch.index_select(w, axis, perm[p].int())
			except KeyError:
				# If the permutation key is not found, continue to the next axis
				continue

	return w


def apply_permutation(ps: PermutationSpec, perm, params):
	"""Return a copy of `params` with the permutations in `perm` applied."""
	permuted = {}
	for key in params:
		permuted[key] = get_permuted_param(ps, perm, key, params)
	return permuted


def update_model_a(ps: PermutationSpec, perm, model_a, new_alpha):
	"""Blend each weight of `model_a` with its permuted counterpart, in place.

	Each entry becomes original * (1 - new_alpha) + new_alpha * permuted.
	Entries whose shapes do not line up raise RuntimeError (pix2pix and
	inpainting models) and are left unchanged.
	"""
	for key in model_a:
		try:
			permuted = get_permuted_param(ps, perm, key, model_a)
			blended = model_a[key] * (1 - new_alpha) + new_alpha * permuted
		except RuntimeError:  # dealing with pix2pix and inpainting models
			continue
		model_a[key] = blended
	return model_a


def inner_matching(
	n,
	ps,
	p,
	params_a,
	params_b,
	usefp16,
	progress,
	number,
	linear_sum,
	perm,
	device,
):
	"""Run one linear-assignment update for permutation `p`.

	Accumulates the n x n similarity matrix `A` between model A's rows and
	model B's (already-permuted) rows over every (weight, axis) pair that
	`p` acts on, solves the assignment problem on it, and stores the new
	permutation in `perm[p]`.

	Returns:
		(linear_sum, number, perm, progress): `linear_sum` and `number`
		accumulate the absolute objective changes and their count;
		`progress` becomes True if the new permutation improved the
		objective beyond a small tolerance.
	"""
	A = torch.zeros((n, n), dtype=torch.float16) if usefp16 else torch.zeros((n, n))
	A = A.to(device)

	for wk, axis in ps.perm_to_axes[p]:
		w_a = params_a[wk]
		# Permute model B everywhere except the axis currently being solved.
		w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
		w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(device)
		w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).T.to(device)

		if usefp16:
			w_a = w_a.half().to(device)
			w_b = w_b.half().to(device)

		try:
			A += torch.matmul(w_a, w_b)
		except RuntimeError:
			# Quantized tensors cannot matmul directly; dequantize first.
			A += torch.matmul(torch.dequantize(w_a), torch.dequantize(w_b))

	# scipy's solver needs a CPU numpy array.
	A = A.cpu()
	ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
	A = A.to(device)

	# Row indices come back sorted, so `ci` alone encodes the permutation.
	assert (torch.tensor(ri) == torch.arange(len(ri))).all()

	eye_tensor = torch.eye(n).to(device)

	# Objective value under the old and the newly proposed permutation.
	oldL = torch.vdot(
		torch.flatten(A).float(), torch.flatten(eye_tensor[perm[p].long()])
	)
	newL = torch.vdot(torch.flatten(A).float(), torch.flatten(eye_tensor[ci, :]))

	if usefp16:
		oldL = oldL.half()
		newL = newL.half()

	if newL - oldL != 0:
		linear_sum += abs((newL - oldL).item())
		number += 1
		# BUG FIX: this was a plain string literal, so the braces printed
		# verbatim; it needs the f-prefix to interpolate the values.
		print(f"Merge Rebasin permutation: {p}={newL - oldL}")

	progress = progress or newL > oldL + 1e-12

	perm[p] = torch.Tensor(ci).to(device)

	return linear_sum, number, perm, progress


def weight_matching(
	ps: PermutationSpec,
	params_a,
	params_b,
	max_iter=1,
	init_perm=None,
	usefp16=False,
	device="cpu",
):
	"""Find permutations aligning `params_b` to `params_a` (git re-basin).

	Args:
		ps: permutation spec relating permutation names to weight axes.
		params_a: reference model state dict.
		params_b: state dict whose units get permuted to match `params_a`.
		max_iter: number of coordinate-descent sweeps to run.
		init_perm: optional starting permutations; identity when None.
		usefp16: run the similarity accumulation in float16.
		device: torch device string for the intermediate math.

	Returns:
		(perm, average): the permutation dict and the mean absolute
		objective change over all updates (0 when nothing changed).
	"""
	# Size of each permutation = length of the first axis it acts on.
	perm_sizes = {
		p: params_a[axes[0][0]].shape[axes[0][1]]
		for p, axes in ps.perm_to_axes.items()
		if axes[0][0] in params_a.keys()
	}
	# Start from the identity permutation unless one was supplied.
	# (Removed a dead `perm = {}` assignment that was immediately overwritten.)
	perm = (
		{p: torch.arange(n).to(device) for p, n in perm_sizes.items()}
		if init_perm is None
		else init_perm
	)

	linear_sum = 0
	number = 0

	# NOTE(review): only this single hard-coded permutation is optimized here.
	special_layers = ["P_bg324"]
	for _i in range(max_iter):
		progress = False
		shuffle(special_layers)
		for p in special_layers:
			n = perm_sizes[p]
			linear_sum, number, perm, progress = inner_matching(
				n,
				ps,
				p,
				params_a,
				params_b,
				usefp16,
				progress,
				number,
				linear_sum,
				perm,
				device,
			)
		# NOTE(review): forcing progress=True makes the early-exit below dead
		# code, so the loop always runs the full max_iter sweeps — confirm
		# whether that is intentional before removing either line. Kept as-is
		# to preserve behavior.
		progress = True
		if not progress:
			break

	average = linear_sum / number if number > 0 else 0
	return perm, average
